{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "from utils import *"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "[[chapter_callbacks]]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Callbacks"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Introduction to callbacks"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since we now know how to create state-of-the-art architectures for computer vision, natural image processing, tabular analysis, and collaborative filtering, and we know how to train them quickly with accelerated optimisers, and we know how to regularise them effectively, we're done, right?\n",
    "\n",
    "Well… Yes, sort of. But other things come up. Sometimes you need to change how things work a little bit. In fact, we have already seen examples of this: mixup, FP16 training, resetting the model after each epoch for training RNNs, and so forth. How do we go about making these kinds of tweaks to the training process?\n",
    "\n",
    "We've seen the basic training loop, which, with the help of the `Optimizer` class, looks like this for a single epoch:\n",
    "\n",
    "```python\n",
    "for xb,yb in dl:\n",
    "    loss = loss_func(model(xb), yb)\n",
    "    loss.backward()\n",
    "    opt.step()\n",
    "    opt.zero_grad()\n",
    "```\n",
    "\n",
    "Here's one way to picture that:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img alt=\"Basic training loop\" width=\"300\" caption=\"Basic training loop\" id=\"basic_loop\" src=\"images/att_00048.png\">"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The usual way for deep learning practitioners to customise the training loop is to make a copy of an existing training loop, and then insert their code necessary for their particular changes into it. This is how nearly all code that you find online will look. But it has some very serious problems.\n",
    "\n",
    "It's not very likely that some particular tweaked training loop is going to meet your particular needs. There are hundreds of changes that can be made to a training loop, which means there are billions and billions of possible permutations. You can't just copy one tweak from a training loop here, another from a training loop there, and expect them all to work together. Each will be based on different assumptions about the environment that it's working in, use different naming conventions, and expect the data to be in different formats.\n",
    "\n",
    "We need a way to allow users to insert their own code at any part of the training loop, but in a consistent and well-defined way. Computer scientists have already come up with an answer to this question: the callback. A callback is a piece of code that you write, and inject into another piece of code at some predefined point. In fact, callbacks have been used with deep learning training loops for years. The problem is that only a small subset of places that may require code injection have been available in previous libraries, and, more importantly, callbacks were not able to do all the things they needed to do.\n",
    "\n",
    "In order to be just as flexible as manually copying and pasting a training loop and directly inserting code into it, a callback must be able to read every possible piece of information available in the training loop, modify all of it as needed, and fully control when a batch, epoch, or even all the whole training loop should be terminated. fastai is the first library to provide all of this functionality. It modifies the training loop so it looks like this:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img alt=\"Training loop with callbacks\" width=\"550\" caption=\"Training loop with callbacks\" id=\"cb_loop\" src=\"images/att_00049.png\">"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The real test of whether this works has been borne out over the last couple of years — it has turned out that every single new paper implemented, or use a request fulfilled, for modifying the training loop has successfully been achieved entirely by using the fastai callback system. The training loop itself has not required modifications. Here are just a few of the callbacks that have been added:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img alt=\"Some fastai callbacks\" width=\"500\" caption=\"Some fastai callbacks\" id=\"some_cbs\" src=\"images/att_00050.png\">"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The reason that this is important for all of us is that it means that whatever idea we have in our head, we can implement it. We need never dig into the source code of PyTorch or fastai and act together some one-off system to try out our ideas. And when we do implement our own callbacks to develop our own ideas, we know that they will work together with all of the other functionality provided by fastai – so we will get progress bars, mixed precision training, hyperparameter annealing, and so forth.\n",
    "\n",
    "Another advantage is that it makes it easy to gradually remove or add functionality and perform ablation studies. You just need to adjust the list of callbacks you pass along to your fit function."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an example, here is the fastai source code that is run for each batch of the training loop:\n",
    "\n",
    "```python\n",
    "try:\n",
    "    self._split(b);                                  self('begin_batch')\n",
    "    self.pred = self.model(*self.xb);                self('after_pred')\n",
    "    self.loss = self.loss_func(self.pred, *self.yb); self('after_loss')\n",
    "    if not self.training: return\n",
    "    self.loss.backward();                            self('after_backward')\n",
    "    self.opt.step();                                 self('after_step')\n",
    "    self.opt.zero_grad()\n",
    "except CancelBatchException:                         self('after_cancel_batch')\n",
    "finally:                                             self('after_batch')\n",
    "```\n",
    "\n",
    "The calls of the form `self('...')` are where the callbacks are called. As you see, after every step a callback is called. The callback will receive the entire state of training, and can also modify it. For instance, as you see above, the input data and target labels are in `self.xb` and `self.yb` respectively. A callback can modify these to modify the data the training loop sees. It can also modify `self.loss`, or even modify the gradients."
   ]
  },
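  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the dispatch idea concrete, here is a minimal sketch in plain Python. It is not fastai's implementation: `MiniLearner`, `MiniCallback`, and `DoubleLoss` are hypothetical names, and the \"batch\" is just a number standing in for a loss. It only shows how a call such as `self('after_loss')` could look up and run a method of that name on each callback, and how a callback can modify the training state:\n",
    "\n",
    "```python\n",
    "# Toy sketch only: MiniLearner and MiniCallback are hypothetical, not fastai classes.\n",
    "class MiniCallback:\n",
    "    def __call__(self, event_name):\n",
    "        # Run the method named after the event, if this callback defines one.\n",
    "        handler = getattr(self, event_name, None)\n",
    "        if handler is not None: handler()\n",
    "\n",
    "class MiniLearner:\n",
    "    def __init__(self, cbs):\n",
    "        self.cbs, self.loss = cbs, None\n",
    "        for cb in cbs: cb.learn = self   # give each callback access to the state\n",
    "\n",
    "    def __call__(self, event_name):\n",
    "        # Dispatch the event to every callback, in order.\n",
    "        for cb in self.cbs: cb(event_name)\n",
    "\n",
    "    def one_batch(self, loss):\n",
    "        self('begin_batch')\n",
    "        self.loss = loss; self('after_loss')\n",
    "        self('after_batch')\n",
    "\n",
    "class DoubleLoss(MiniCallback):\n",
    "    def after_loss(self): self.learn.loss *= 2   # modify training state in place\n",
    "\n",
    "learn = MiniLearner([DoubleLoss()])\n",
    "learn.one_batch(1.5)\n",
    "print(learn.loss)   # 3.0\n",
    "```\n",
    "\n",
    "Because events that a callback does not define are simply skipped, each callback only needs to implement the handful of events it cares about."
   ]
  },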
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Creating a callback"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The full list of available callback events is:\n",
    "\n",
    "- `begin_fit`: called before doing anything, ideal for initial setup.\n",
    "- `begin_epoch`: called at the beginning of each epoch, useful for any behavior you need to reset at each epoch.\n",
    "- `begin_train`: called at the beginning of the training part of an epoch.\n",
    "- `begin_batch`: called at the beginning of each batch, just after drawing said batch. It can be used to do any setup necessary for the batch (like hyper-parameter scheduling) or to change the input/target before it goes in the model (change of the input with techniques like mixup for instance).\n",
    "- `after_pred`: called after computing the output of the model on the batch. It can be used to change that output before it's fed to the loss.\n",
    "- `after_loss`: called after the loss has been computed, but before the backward pass. It can be used to add any penalty to the loss (AR or TAR in RNN training for instance).\n",
    "- `after_backward`: called after the backward pass, but before the update of the parameters. It can be used to do any change to the gradients before said update (gradient clipping for instance).\n",
    "- `after_step`: called after the step and before the gradients are zeroed.\n",
    "- `after_batch`: called at the end of a batch, for any clean-up before the next one.\n",
    "- `after_train`: called at the end of the training phase of an epoch.\n",
    "- `begin_validate`: called at the beginning of the validation phase of an epoch, useful for any setup needed specifically for validation.\n",
    "- `after_validate`: called at the end of the validation part of an epoch.\n",
    "- `after_epoch`: called at the end of an epoch, for any clean-up before the next one.\n",
    "- `after_fit`: called at the end of training, for final clean-up.\n",
    "\n",
    "This list is available as attributes of the special variable `event`; so just type `event.` and hit `Tab` in your notebook to see a list of all the options"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's take a look at an example. Do you recall how in <<chapter_nlp_dive>> we needed to ensure that our special `reset` method was called at the start of training and validation for each epoch? We used the `ModelReseter` callback provided by fastai to do this for us. But how did `ModelReseter` do that exactly? Here's the full actual source code to that class:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ModelReseter(Callback):\n",
    "    def begin_train(self):    self.model.reset()\n",
    "    def begin_validate(self): self.model.reset()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Yes, that's actually it! It just does what we said in the paragraph above: after completing training and epoch or validation for an epoch, call a method named `reset`.\n",
    "\n",
    "Callbacks are often \"short and sweet\" like this one. In fact, let's look at one more. Here's the fastai source for the callback that add RNN regularization (*AR* and *TAR*):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class RNNRegularizer(Callback):\n",
    "    def __init__(self, alpha=0., beta=0.): self.alpha,self.beta = alpha,beta\n",
    "\n",
    "    def after_pred(self):\n",
    "        self.raw_out,self.out = self.pred[1],self.pred[2]\n",
    "        self.learn.pred = self.pred[0]\n",
    "\n",
    "    def after_loss(self):\n",
    "        if not self.training: return\n",
    "        if self.alpha != 0.:\n",
    "            self.learn.loss += self.alpha * self.out[-1].float().pow(2).mean()\n",
    "        if self.beta != 0.:\n",
    "            h = self.raw_out[-1]\n",
    "            if len(h)>1:\n",
    "                self.learn.loss += self.beta * (h[:,1:] - h[:,:-1]\n",
    "                                               ).float().pow(2).mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> stop: Go back to where we discussed TAR and AR regularization, and compare to the code here. Made sure you understand what it's doing, and why."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In both of these examples, notice how we can access attributes of the training loop by directly checking `self.model` or `self.pred`. That's because a `Callback` will always try to get an attribute it doesn't have inside the `Learner` associated to it. This is a shortcut for `self.learn.model` or `self.learn.pred`. Note that this shortcut works for reading attributes, but not for writing them, which is why when `RNNRegularizer` changes the loss or the predictions, you see `self.learn.loss = ` or `self.learn.pred = `. "
   ]
  },
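  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The read-but-not-write behaviour is a general property of Python attribute forwarding. Here is a minimal sketch of it, under the assumption that the shortcut is implemented with something like `__getattr__` (the `ToyLearner` and `ToyCallback` classes are hypothetical stand-ins, not fastai's code):\n",
    "\n",
    "```python\n",
    "# Toy sketch only: hypothetical stand-ins, not fastai's implementation.\n",
    "class ToyLearner:\n",
    "    def __init__(self): self.loss = 1.0\n",
    "\n",
    "class ToyCallback:\n",
    "    def __init__(self, learn): self.learn = learn\n",
    "    def __getattr__(self, name):\n",
    "        # Only called when normal lookup fails, so it handles reads, never writes.\n",
    "        if name == 'learn': raise AttributeError(name)\n",
    "        return getattr(self.learn, name)\n",
    "\n",
    "learn = ToyLearner()\n",
    "cb = ToyCallback(learn)\n",
    "read_value = cb.loss    # the read is forwarded to learn.loss\n",
    "cb.loss = 5.0           # the write creates a new attribute on the callback only\n",
    "shadowed = learn.loss   # the learner still holds the old value\n",
    "cb.learn.loss = 7.0     # writing through cb.learn changes the real state\n",
    "print(read_value, shadowed, learn.loss)   # 1.0 1.0 7.0\n",
    "```\n",
    "\n",
    "In this sketch, assigning to `cb.loss` silently leaves the learner untouched, which is exactly the kind of bug that writing to `self.learn.loss` avoids."
   ]
  },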
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "When writing a callback, the following attributes of `Learner` are available:\n",
    "\n",
    "- `model`: the model used for training/validation\n",
    "- `data`: the underlying `DataLoaders`\n",
    "- `loss_func`: the loss function used\n",
    "- `opt`: the optimizer used to udpate the model parameters\n",
    "- `opt_func`: the function used to create the optimizer\n",
    "- `cbs`: the list containing all `Callback`s\n",
    "- `dl`: current `DataLoader` used for iteration\n",
    "- `x`/`xb`: last input drawn from `self.dl` (potentially modified by callbacks). `xb` is always a tuple (potentially with one element) and `x` is detuplified. You can only assign to `xb`.\n",
    "- `y`/`yb`: last target drawn from `self.dl` (potentially modified by callbacks). `yb` is always a tuple (potentially with one element) and `y` is detuplified. You can only assign to `yb`.\n",
    "- `pred`: last predictions from `self.model` (potentially modified by callbacks)\n",
    "- `loss`: last computed loss (potentially modified by callbacks)\n",
    "- `n_epoch`: the number of epochs in this training\n",
    "- `n_iter`: the number of iterations in the current `self.dl`\n",
    "- `epoch`: the current epoch index (from 0 to `n_epoch-1`)\n",
    "- `iter`: the current iteration index in `self.dl` (from 0 to `n_iter-1`)\n",
    "\n",
    "The following attributes are added by `TrainEvalCallback` and should be available unless you went out of your way to remove that callback:\n",
    "\n",
    "- `train_iter`: the number of training iterations done since the beginning of this training\n",
    "- `pct_train`: from 0. to 1., the percentage of training iterations completed\n",
    "- `training`:  flag to indicate if we're in training mode or not\n",
    "\n",
    "The following attribute is added by `Recorder` and should be available unless you went out of your way to remove that callback:\n",
    "\n",
    "- `smooth_loss`: an exponentially-averaged version of the training loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Callback ordering and exceptions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Sometimes, callbacks need to be able to tell fastai to skip over a batch, or an epoch, or stop training altogether. For instance, consider `TerminateOnNaNCallback`. This handy callback will automatically stop training any time the loss becomes infinite or `NaN` (*not a number*). Here's the fastai source for this callback:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TerminateOnNaNCallback(Callback):\n",
    "    run_before=Recorder\n",
    "    def after_batch(self):\n",
    "        if torch.isinf(self.loss) or torch.isnan(self.loss):\n",
    "            raise CancelFitException"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The way it tells the training loop to interrupt training at this point is to `raise CancelFitException`. The training loop catches this exception and does not run any further training or validation. The callback control flow exceptions available are:\n",
    "\n",
    "- `CancelFitException`: Skip the rest of this batch and go to `after_batch\n",
    "- `CancelEpochException`: Skip the rest of the training part of the epoch and go to `after_train\n",
    "- `CancelTrainException`: Skip the rest of the validation part of the epoch and go to `after_validate\n",
    "- `CancelValidException`: Skip the rest of this epoch and go to `after_epoch\n",
    "- `CancelBatchException`: Interrupts training and go to `after_fit"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can detect one of those exceptions occurred and add code that executes right after with the following events:\n",
    "\n",
    "- `after_cancel_batch`: reached immediately after a `CancelBatchException` before proceeding to `after_batch`\n",
    "- `after_cancel_train`: reached immediately after a `CancelTrainException` before proceeding to `after_epoch`\n",
    "- `after_cancel_valid`: reached immediately after a `CancelValidException` before proceeding to `after_epoch`\n",
    "- `after_cancel_epoch`: reached immediately after a `CancelEpochException` before proceeding to `after_epoch`\n",
    "- `after_cancel_fit`: reached immediately after a `CancelFitException` before proceeding to `after_fit`"
   ]
  },
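  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The nesting of these exceptions and events can be pictured with a small stand-alone sketch. Everything below is a simplified assumption for illustration: the exception classes are redefined locally, `toy_fit` is a hypothetical function, each \"batch\" is just a loss value, and the rule that a negative loss cancels the batch is arbitrary. The point is only to show how `try`/`except`/`finally` blocks produce the event order described above:\n",
    "\n",
    "```python\n",
    "# Toy sketch only: simplified stand-ins for the fastai exceptions and loop.\n",
    "class CancelBatchException(Exception): pass\n",
    "class CancelFitException(Exception): pass\n",
    "\n",
    "def toy_fit(losses, events):\n",
    "    try:\n",
    "        for loss in losses:\n",
    "            try:\n",
    "                if loss != loss: raise CancelFitException()   # NaN is not equal to itself\n",
    "                if loss < 0: raise CancelBatchException()     # arbitrary rule for the demo\n",
    "                events.append(f'step {loss}')\n",
    "            except CancelBatchException: events.append('after_cancel_batch')\n",
    "            finally: events.append('after_batch')             # always runs\n",
    "    except CancelFitException: events.append('after_cancel_fit')\n",
    "    finally: events.append('after_fit')                       # always runs\n",
    "    return events\n",
    "\n",
    "events = toy_fit([0.5, -1.0, float('nan'), 0.1], [])\n",
    "for e in events: print(e)\n",
    "```\n",
    "\n",
    "Note that the last batch is never processed: once the fit-level exception is raised, the loop ends, but the `after_batch` and `after_fit` clean-up events still run."
   ]
  },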
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Sometimes, callbacks need to be called in a particular order. In the case of `TerminateOnNaNCallback`, it's important that `Recorder` runs its `after_batch` after this callback, to avoid registering an NaN loss. You can specify `run_before` (this callback must run before ...) or `run_after` (this callback must run after ...) in your callback to ensure the ordering that you need."
   ]
  },
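  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to think about `run_before` and `run_after` is as constraints for a sort. The following sketch shows a simple way such constraints could be resolved. It is an assumption for illustration, and fastai's actual sorting logic may differ. The classes `A`, `B`, and `C` and the helpers `must_precede` and `sort_callbacks` are all hypothetical:\n",
    "\n",
    "```python\n",
    "# Toy sketch only: fastai's real ordering logic may differ.\n",
    "class A: pass\n",
    "class B: run_before = A   # B must run before A\n",
    "class C: run_after = A    # C must run after A\n",
    "\n",
    "def must_precede(x, y):\n",
    "    # True if x declares run_before=y, or y declares run_after=x.\n",
    "    return getattr(x, 'run_before', None) is y or getattr(y, 'run_after', None) is x\n",
    "\n",
    "def sort_callbacks(cbs):\n",
    "    remaining, ordered = list(cbs), []\n",
    "    while remaining:\n",
    "        # Pick the first callback that no other remaining callback must precede.\n",
    "        for cb in remaining:\n",
    "            if not any(must_precede(o, cb) for o in remaining if o is not cb):\n",
    "                ordered.append(cb); remaining.remove(cb); break\n",
    "        else: raise ValueError('cyclic ordering constraints')\n",
    "    return ordered\n",
    "\n",
    "order = [cb.__name__ for cb in sort_callbacks([C, A, B])]\n",
    "print(order)   # ['B', 'A', 'C']\n",
    "```\n",
    "\n",
    "Whatever order the callbacks are passed in, the declared constraints decide the order in which they run."
   ]
  },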
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we have seen how to tweak the training loop of fastai to do anything we need, let's take a step back and dig a little bit deeper in the foundations of that training loop."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Questionnaire"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1. What are the four steps of a training loop?\n",
    "1. Why is the use of callbacks better than writing a new training loop for each tweak you want to add?\n",
    "1. What are the necessary points in the design of the fastai's callback system that make it as flexible as copying and pasting bits of code?\n",
    "1. How can you get the list of events available to you when writing a callback?\n",
    "1. Write the `ModelResetter` callback (without peeking).\n",
    "1. How can you access the necessary attributes of the training loop inside a callback? When can you use or not use the shortcut that goes with it?\n",
    "1. How can a callback influence the control flow of the training loop.\n",
    "1. Write the `TerminateOnNaN` callback (without peeking if possible).\n",
    "1. How do you make sure your callback runs after or before another callback?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Further research"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1. Look at the mixed precision callback with the documentation. Try to understand what each event and line of code does.\n",
    "1. Implement your own version of ther learning rate finder from scratch. Compare it with fastai's version.\n",
    "1. Look at the source code of the callbacks that ship with fastai. See if you can find one that's similar to what you're looking to do, to get some inspiration."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Foundations of Deep Learning: Wrap up"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Congratulations, you have made it to the end of the \"foundations of deep learning\" section. You now understand how all of fastai's applications and most important architectures are built, and the recommended ways to train them, and have all the information you need to build these from scratch. Whilst you probably won't need to create your own training loop, or batchnorm layer, for instance, knowing what is going on behind the scenes is very helpful for debugging, profiling, and deploying your solutions.\n",
    "\n",
    "Since you understand all of the foundations of fastai's applications now, be sure to spend some time digging through fastai's source notebooks, and running and experimenting with parts of them, since you can and see exactly how everything in fastai is developed.\n",
    "\n",
    "In the next section, we will be looking even further under the covers, to see how the actual forward and backward passes of a neural network are done, and we will see what tools are at our disposal to get better performance. We will then finish up with a project that brings together everything we have learned throughout the book, which we will use to build a method for interpreting convolutional neural networks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "jupytext": {
   "split_at_heading": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": false,
   "sideBar": true,
   "skip_h1_title": true,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
